A second attempt at a nowcasting model

from jax import config

config.update("jax_enable_x64", True)
from ssm4epi.models.hospitalization import (
    hospitalization_model,
    dates,
    unique_a,
    h_by_age,
    I_by_age,
)
from ssm4epi.models.util import from_consecutive_logits
from isssm.importance_sampling import mc_integration
from isssm.kalman import state_mode
from jax import vmap
from ssm4epi.patch import full_deps
import jax.numpy as jnp
from pyprojroot import here
import matplotlib.pyplot as plt
from isssm.estimation import initial_theta, mle_pgssm
from isssm.laplace_approximation import laplace_approximation as LA
from isssm.modified_efficient_importance_sampling import (
    modified_efficient_importance_sampling as MEIS,
)
from isssm.importance_sampling import pgssm_importance_sampling, ess_pct
import jax.random as jrn
import pandas as pd
import fastcore.test as fct
from tqdm.notebook import tqdm
from typing import NamedTuple
from jaxtyping import Float, Array, PRNGKeyArray
from scipy.optimize import OptimizeResult
from isssm.typing import GLSSMProposal
from datetime import date

from ssm4epi.models.hospitalization import hospitalization_data

We focus only on nowcasting for the 00+ age group.

Hospitalization and incidence data

# Keep only the rows of the "00+" age group.
hosp_data = hospitalization_data[hospitalization_data.a == "00+"]
# Wide matrix of `h`: one row per value of `s`, one column per value of `k`.
h = hosp_data.pivot(index="s", columns="k", values="h").to_numpy()
# `I` is pivoted the same way, but only the first `k` column is kept.
I = hosp_data.pivot(index="s", columns="k", values="I").to_numpy()[:, 0]
# NOTE: this rebinds `dates`, replacing the name imported from
# ssm4epi.models.hospitalization above. The values are taken from the
# unfiltered frame (all age groups) and converted to `datetime.date` objects.
dates = pd.to_datetime(hospitalization_data["s"].unique(), format="%Y-%m-%d").date
# One line per column of `h`, labelled by its column position ("Delay i").
plt.figure(figsize=(16, 5))
plt.plot(pd.to_datetime(dates, format="%Y-%m-%d"), h)
plt.legend([f"Delay {i}" for i in range(h.shape[1])])
plt.show()

# Incidence series over the same dates.
plt.figure(figsize=(16, 5))
plt.plot(pd.to_datetime(dates, format="%Y-%m-%d"), I)
plt.legend([f"Incidence"])
plt.show()

Data structures

class PredictionResult(NamedTuple):
    """Container for summary statistics of a prediction.

    Field order matters: instances can be indexed positionally
    (0 = mean, 1 = sd, 2 = quantiles, 3 = quantiles_of_interest).
    """

    mean: Float[Array, "..."]  # predictive mean
    sd: Float[Array, "..."]  # predictive standard deviation
    quantiles: Float[Array, "k ..."]  # leading axis has length k
    quantiles_of_interest: Float[Array, "k"]  # presumably the k probability levels of `quantiles` -- TODO confirm


class HospitalizationNowcastingResult(NamedTuple):
    """Record bundling the inputs, fitted quantities and prediction of one nowcast.

    NOTE(review): this type is not constructed anywhere in the visible part of
    the notebook; the field comments below are read off the names and
    annotations and should be confirmed against the producing code.
    """

    a_index: int  # presumably the index of the age group -- TODO confirm
    h: Float[Array, "..."]  # hospitalization data
    I: Float[Array, "..."]  # incidence data
    dates: pd.DatetimeIndex  # dates covered by the data
    y_miss: Float[Array, "np1 p"]  # observations with missing entries accounted for
    y_true: Float[Array, "np1 p"]  # reference observations
    theta0: Float[Array, "5"]  # parameter vector of length 5
    theta0_result: OptimizeResult  # optimizer output that produced `theta0`
    proposal: GLSSMProposal  # Gaussian proposal
    key: PRNGKeyArray  # PRNG key
    ess_pct: Float  # effective sample size, in percent
    y_sum_predict: PredictionResult  # prediction summary
from ssm4epi.models.hospitalization import estimate_theta0_missing, LA_missing
  • Parameter estimation
    • only for fully available observations
  • LA for full missing model
  • nowcasting
from jaxtyping import PRNGKeyArray


def find_index_of_date(dates: pd.DatetimeIndex, target_date: date) -> int:
    """Return the position of ``target_date`` within ``dates``.

    The lookup is an element-wise ``dates == target_date`` comparison, so
    ``dates`` must hold values that compare equal to ``target_date`` (for
    example an array of ``datetime.date`` objects).

    Args:
        dates: sequence of dates to search.
        target_date: the date to locate.

    Returns:
        The zero-based position of ``target_date`` as a Python ``int``.

    Raises:
        ValueError: if ``target_date`` does not occur exactly once in ``dates``.
    """
    positions = jnp.flatnonzero(jnp.asarray(dates == target_date))
    # Fail with a message naming the date instead of a tuple-unpacking error.
    if positions.size != 1:
        raise ValueError(
            f"expected exactly one occurrence of {target_date} in dates, "
            f"found {positions.size}"
        )
    return int(positions[0])
# Smoke tests: the first and last entries of `dates` map to the first and last positions.
fct.test_eq(find_index_of_date(dates, dates[0]), 0)
fct.test_eq(find_index_of_date(dates, dates[-1]), len(dates) - 1)
class Configuration(NamedTuple):
    """Inputs for a single nowcasting run over a contiguous window of dates."""

    dates: pd.DatetimeIndex  # daily dates from window start to window end (inclusive)
    n_delay: int  # number of delays
    key: PRNGKeyArray  # PRNG key used for all random draws of the run
    h: Float[Array, "n n_delay"]  # `h` values, one row per date and one column per delay
    I: Float[Array, "n"]  # `I` values, one entry per date
    n: int  # number of rows of `h`
    n_weekday: int = 2  # number of weekday effects

    @classmethod
    def from_dates(
        cls, start: date, end: date, n_delay: int, seed: int = 2423901241
    ) -> "Configuration":
        """Build a configuration for the window ``start`` .. ``end`` (inclusive).

        Reads the module-level ``dates`` and ``hosp_data`` defined earlier in
        the notebook.

        Args:
            start: first date of the window; must be present in ``dates``.
            end: last date of the window; must be present in ``dates``.
            n_delay: maximum number of delay columns to keep. The stored
                ``n_delay`` is the number of columns actually available.
            seed: seed of the PRNG key stored in the configuration. The
                default reproduces the previously hard-coded key.

        Returns:
            A new configuration with ``n_weekday`` left at its default.

        Raises:
            ValueError: if ``start`` or ``end`` is not found in ``dates``, or
                if ``start`` lies after ``end``.
        """
        i_start = find_index_of_date(dates, start)
        i_end = find_index_of_date(dates, end)
        if i_start > i_end:
            raise ValueError(f"start ({start}) must not be after end ({end})")
        rows = slice(i_start, i_end + 1)

        h = jnp.array(
            hosp_data.pivot(index="s", columns="k", values="h").to_numpy()[
                rows, :n_delay
            ]
        )
        I = jnp.array(
            hosp_data.pivot(index="s", columns="k", values="I").to_numpy()[rows, 0]
        )
        # The data may provide fewer delay columns than requested, so the
        # stored n_delay is taken from the sliced matrix.
        n, n_delay = h.shape

        # NOTE(review): `dates` below is a plain daily range; it matches the
        # rows of `h` only if the data has exactly one row per day -- confirm.
        return cls(
            dates=pd.date_range(start, end, freq="D"),
            n_delay=n_delay,
            key=jrn.PRNGKey(seed),
            h=h,
            I=I,
            n=n,
        )
# Example configuration covering one year, followed by smoke tests of the constructor.
# NOTE: `config` rebinds the name introduced by `from jax import config` at the
# top of the notebook; re-running the `config.update(...)` cell after this one
# requires re-importing jax's config first.
start_date = date(2022, 1, 1)
end_date = date(2023, 1, 1)
config = Configuration.from_dates(start_date, end_date, n_delay=7)
# The stored dates span exactly the requested window and the delay count is kept.
fct.test_eq(config.dates[0].date(), start_date)
fct.test_eq(config.dates[-1].date(), end_date)
fct.test_eq(config.n_delay, 7)
from isssm.typing import GLSSM, PGSSM
from isssm.kalman import simulation_smoother
from functools import partial
from isssm.importance_sampling import log_weights


def pgssm_importance_sampling_missing(
    y: Float[Array, "n+1 p"],  # observations
    model: PGSSM,  # model
    z: Float[Array, "n+1 p"],  # synthetic observations
    Omega: Float[Array, "n+1 p p"],  # covariance of synthetic observations
    N: int,  # number of samples
    key: PRNGKeyArray,  # random key
) -> tuple[
    Float[Array, "N n+1 m"], Float[Array, "N"]
]:  # importance samples and weights
    """Draw importance samples when some synthetic observations are missing.

    Time points whose row of ``z`` contains any NaN are treated as missing:
    the corresponding entries of ``B``, ``v``, ``Omega`` and ``z`` are set to
    zero before the Gaussian surrogate model is assembled.

    Returns the ``N`` draws of the simulation smoother together with the log
    weight of every draw.
    """
    u, A, D, Sigma0, Sigma, v, B, dist, xi = model

    # A single NaN in a row of z marks the whole time point as missing.
    rows_without_z = jnp.isnan(z).any(axis=-1)
    if jnp.any(rows_without_z):
        z = z.at[rows_without_z].set(0.0)
        Omega = Omega.at[rows_without_z].set(0.0)
        v = v.at[rows_without_z].set(0.0)
        B = B.at[rows_without_z].set(0.0)

    surrogate = GLSSM(u, A, D, Sigma0, Sigma, v, B, Omega)

    _, sampling_key = jrn.split(key)
    draws = simulation_smoother(surrogate, z, N, sampling_key)

    def weigh(draw):
        # log weight of one draw under the (zero-filled) proposal
        return log_weights(draw, y=y, dist=dist, xi=xi, z=z, Omega=Omega)

    return draws, vmap(weigh)(draws)
from ssm4epi.models.hospitalization import (
    account_for_nans,
    make_y_nan,
    estimate_theta0_missing,
    LA_missing,
)
from isssm.importance_sampling import prediction


def make_theta_manual(y, I):
    """Return a hand-picked starting value for the parameters, on the log scale.

    The vector holds, in order, the logs of ``s2_p``, ``s2_q``, ``s2_W``,
    ``s2_0`` and ``p0``. The four variances are fixed constants; ``p0`` is
    the sum of the first seven rows of ``y`` divided by the sum of the first
    seven entries of ``I``.
    """
    first_rows = slice(None, 7)
    ratio = y[first_rows].sum() / I[first_rows].sum()

    s2_p = 1**2
    s2_q = 1**2
    s2_W = 0.1**2
    s2_0 = 0.1**2

    natural_scale = jnp.array([s2_p, s2_q, s2_W, s2_0, ratio])
    return jnp.log(natural_scale)


# Probability levels (as fractions in [0, 1]) passed to `prediction` and `result_to_series`.
percentiles_of_interest = jnp.array([0.025, 0.1, 0.25, 0.5, 0.75, 0.9, 0.975])


class NowcastingDebuggingResult(NamedTuple):
    """Intermediate quantities of one `nowcast_hospitalizations` run, kept for debugging."""

    h: Float[Array, "n n_delay"]  # hospitalization matrix taken from the configuration
    I: Float[Array, "n"]  # incidence vector taken from the configuration
    y_miss: Float[Array, "np1 p"]  # observations returned by `account_for_nans`
    y_true: Float[Array, "np1 p"]  # output of `make_y_nan(h)`
    theta_manual: Float[Array, "5"]  # starting value from `make_theta_manual`
    theta0: Float[Array, "5"]  # estimate, i.e. `theta0_result.x`
    theta0_result: OptimizeResult  # full result of `estimate_theta0_missing`
    proposal: GLSSMProposal  # proposal from `LA_missing`
    key: PRNGKeyArray  # PRNG key left over after all splits
    ess_pct: Float  # `ess_pct` of the importance-sampling log weights


def nowcast_hospitalizations(
    config: Configuration,
    n_samples: int = 10000,
    n_iter_la: int = 100,
    eps_la: float = 1e-10,
):
    """Run one nowcast for the window described by ``config``.

    Steps:
      1. mask the data with ``make_y_nan`` and estimate ``theta0`` from it,
         starting at ``make_theta_manual``;
      2. adapt the model and the observations to the missing entries with
         ``account_for_nans`` and compute a proposal with ``LA_missing``;
      3. if the proposal contains time points whose ``z`` is NaN, treat these
         time points as completely missing, redo the adaptation and set the
         proposal's ``z`` and ``Omega`` to zero there;
      4. run importance sampling to obtain the effective sample size and
         ``prediction`` to obtain the summary of the row sums.

    Args:
        config: data, PRNG key and model dimensions of the run.
        n_samples: number of samples for importance sampling and prediction.
        n_iter_la: third positional argument of ``LA_missing``
            (presumably the iteration limit -- TODO confirm).
        eps_la: ``eps`` argument of ``LA_missing``.

    Returns:
        ``(preds, debugging_result)`` where ``preds`` is the return value of
        ``prediction`` and ``debugging_result`` is a
        ``NowcastingDebuggingResult``.
    """
    key = config.key
    h = config.h
    I = config.I

    np1, n_delay = h.shape
    aux = (np1, n_delay, config.n_weekday, I)

    h_nan = make_y_nan(h)
    theta_manual = make_theta_manual(h_nan, I)
    theta0_result = estimate_theta0_missing(
        h_nan,
        theta_manual,
        aux=aux,
        I=I,
    )
    theta0 = theta0_result.x

    # The complete-data model at theta0 is needed several times; build it once.
    model0 = hospitalization_model(theta0, aux)

    missing_y_indices = jnp.isnan(h_nan)
    # The state mask is the observation mask shifted one column to the right;
    # the first column is never marked as missing.
    missing_s_indices = jnp.concatenate(
        (jnp.full((np1, 1), False, dtype=bool), missing_y_indices[:, :-1]), axis=-1
    )

    model_miss0, y_miss = account_for_nans(
        model0, h_nan, missing_y_indices, missing_s_indices
    )
    proposal_la, _ = LA_missing(y_miss, model_miss0, n_iter_la, eps=eps_la)

    nan_z_indices = jnp.isnan(proposal_la.z).any(axis=-1)
    if jnp.any(nan_z_indices):
        missing_y_indices = missing_y_indices.at[nan_z_indices].set(True)
        # where z is missing completely, s has to be missing as well
        missing_s_indices = missing_s_indices.at[nan_z_indices].set(True)

        model_miss0, y_miss = account_for_nans(
            model0,
            h_nan.at[nan_z_indices].set(jnp.nan),
            missing_y_indices,
            missing_s_indices,
        )
        proposal_la = proposal_la._replace(
            z=proposal_la.z.at[nan_z_indices].set(0.0),
            Omega=proposal_la.Omega.at[nan_z_indices].set(0.0),
        )

    key, subkey = jrn.split(key)
    # Only the log weights are needed here (for the effective sample size).
    _, is_log_weights = pgssm_importance_sampling(
        y_miss,
        model_miss0,
        proposal_la.z,
        proposal_la.Omega,
        n_samples,
        subkey,
    )
    ess_pct_nowcast = ess_pct(is_log_weights)

    def f_nowcast(x, s, y):
        # Row sums where missing entries come from `y` and all other entries
        # from `y_miss`; `x` and `s` are not used.
        # NOTE(review): the arithmetic blend propagates NaN/inf from the
        # de-selected operand (0 * nan == nan); jnp.where would not.
        return jnp.sum(
            (missing_y_indices * y) + (1 - missing_y_indices) * y_miss, axis=-1
        )

    key, subkey = jrn.split(key)
    preds = prediction(
        f_nowcast,
        y_miss,
        proposal_la,
        model_miss0,
        n_samples,
        subkey,
        percentiles_of_interest,
        model0,
    )
    debugging_result = NowcastingDebuggingResult(
        h=h,
        I=I,
        y_miss=y_miss,
        y_true=h_nan,
        theta_manual=theta_manual,
        theta0=theta0,
        theta0_result=theta0_result,
        proposal=proposal_la,
        key=key,
        ess_pct=ess_pct_nowcast,
    )
    return preds, debugging_result
from datetime import timedelta

# One nowcast per day from `date_start` to `date_end`; each uses the
# `n_days_back` days preceding (and including) its nowcast date.
date_start = date(2021, 11, 23)
n_days_back = 100

date_end = date(2022, 4, 29)
configs = [
    Configuration.from_dates(
        start=end - timedelta(days=n_days_back),
        end=end,
        n_delay=7,
    )._replace(n_weekday=0)
    for end in pd.date_range(date_start, date_end, freq="D").date
]

# Keep only the predictions: the cells below index every entry of `results`
# as a prediction tuple (result[0] = mean, result[1] = sd, result[2] = quantiles).
# Dropping the debugging information also keeps memory usage down.
results = [nowcast_hospitalizations(config)[0] for config in tqdm(configs)]
The Kernel crashed while executing code in the current cell or a previous cell. 

Please review the code in the cell(s) to identify a possible cause of the failure. 

Click <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. 

View Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details.

Diagnostics

# Nowcast dates whose predicted mean for the last day of the window is NaN.
nan_dates = [
    str(date_start + timedelta(days=offset))
    for offset, preds in enumerate(results)
    if jnp.isnan(preds[0][-1])
]
REL_DIFF_THRESHOLD = 0.3
# Nowcast dates whose last-day predicted mean deviates from the last-day row
# sum of `h` by more than the threshold, relative to that row sum.
rel_diff_big_dates = [
    str(date_start + timedelta(days=offset))
    for offset, (preds, cfg) in enumerate(zip(results, configs))
    for last_day_total in [cfg.h.sum(axis=-1)[-1]]
    if jnp.abs(preds[0][-1] - last_day_total) / last_day_total > REL_DIFF_THRESHOLD
]

print(f"Dates with NaN results: {nan_dates}")
print(f"Dates with big relative differences: {len(rel_diff_big_dates)}")
Dates with NaN results: ['2022-08-11']
Dates with big relative differences: 30
# Post-processing: rerun every nowcast whose last-day mean is NaN or deviates
# from the last-day row sum of `h` by more than REL_DIFF_THRESHOLD, using a
# configuration with n_weekday=0.
# NOTE(review): the configuration carries its own PRNG key, so rerunning a
# configuration that already has n_weekday=0 presumably reproduces the same
# result -- confirm whether such reruns are intended.

for i, (config, result) in tqdm(enumerate(zip(configs, results)), total=len(configs)):
    true_h = config.h[-1].sum()
    predicted_mean_h = result[0][-1]
    abs_rel_diff = jnp.abs((predicted_mean_h - true_h) / true_h)
    # `nan > threshold` is False, so NaN results have to be detected explicitly.
    if jnp.isnan(abs_rel_diff) or abs_rel_diff > REL_DIFF_THRESHOLD:
        print(f"Rerunning for {config.dates[-1]} with n_weekday=0")
        config = config._replace(n_weekday=0)
        result, debugging_info = nowcast_hospitalizations(config)
        results[i] = result

        predicted_mean_h = result[0][-1]
        abs_rel_diff = jnp.abs((predicted_mean_h - true_h) / true_h)
        print(
            f"New result for {config.dates[-1]}: {predicted_mean_h}, true h: {true_h}, rel diff: {abs_rel_diff}"
        )
Rerunning for 2021-12-02 00:00:00 with n_weekday=0
New result for 2021-12-02 00:00:00: 9819.808576317319, true h: 10232, rel diff: 0.04028454101668115
Rerunning for 2021-12-09 00:00:00 with n_weekday=0
New result for 2021-12-09 00:00:00: 11013.575017011477, true h: 10175, rel diff: 0.08241523508712305
Rerunning for 2022-02-06 00:00:00 with n_weekday=0
New result for 2022-02-06 00:00:00: 8074.097615740208, true h: 9355, rel diff: 0.13692168725385268
Rerunning for 2022-02-07 00:00:00 with n_weekday=0
New result for 2022-02-07 00:00:00: 9159.233606524445, true h: 9468, rel diff: 0.03261157514528461
Rerunning for 2022-03-06 00:00:00 with n_weekday=0
New result for 2022-03-06 00:00:00: 9185.286094649857, true h: 9172, rel diff: 0.001448549351270906
Rerunning for 2022-03-24 00:00:00 with n_weekday=0
New result for 2022-03-24 00:00:00: 11125.70464518987, true h: 11529, rel diff: 0.03498094846128289
Rerunning for 2022-03-29 00:00:00 with n_weekday=0
New result for 2022-03-29 00:00:00: 11811.190008674002, true h: 11443, rel diff: 0.03217600355448767
Rerunning for 2022-03-30 00:00:00 with n_weekday=0
New result for 2022-03-30 00:00:00: 11898.057212203674, true h: 11337, rel diff: 0.049489036976596436
Rerunning for 2022-04-18 00:00:00 with n_weekday=0
New result for 2022-04-18 00:00:00: 5534.035337901045, true h: 7452, rel diff: 0.2573758269053885
Rerunning for 2022-05-10 00:00:00 with n_weekday=0
New result for 2022-05-10 00:00:00: 6147.674543719715, true h: 5177, rel diff: 0.18749749733817173
Rerunning for 2022-05-11 00:00:00 with n_weekday=0
New result for 2022-05-11 00:00:00: 5483.213774518353, true h: 5042, rel diff: 0.08750769030510763
Rerunning for 2022-05-13 00:00:00 with n_weekday=0
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In[84], line 11
      9 print(f"Rerunning for {config.dates[-1]} with n_weekday=0")
     10 config = config._replace(n_weekday=0)
---> 11 result, debugging_info = nowcast_hospitalizations(config)
     12 results[i] = result
     14 predicted_mean_h = result[0][-1]

Cell In[43], line 121, in nowcast_hospitalizations(config)
    116     return jnp.sum(
    117         (missing_y_indices * y) + (1 - missing_y_indices) * y_miss, axis=-1
    118     )
    120 key, subkey = jrn.split(key)
--> 121 preds = prediction(
    122     f_nowcast,
    123     y_miss,
    124     proposal_la,
    125     _model_miss(theta0, aux),
    126     10000,
    127     subkey,
    128     percentiles_of_interest,
    129     hospitalization_model(theta0, aux),
    130 )
    131 debugging_result = NowcastingDebuggingResult(
    132     h=h,
    133     I=I,
   (...)
    141     ess_pct=ess_pct_nowcast,
    142 )
    143 return preds, debugging_result

File ~/workspace/work/packages/isssm/isssm/importance_sampling.py:292, in prediction(f, y, proposal, model, N, key, probs, prediction_model)
    288     percentiles = prediction_percentiles(
    289         f_samples, normalize_weights(log_weights), probs
    290     )
    291 elif f_samples.ndim == 2:
--> 292     percentiles = vmap(_prediction_percentiles, (1, None, None), 1)(
    293         f_samples, normalize_weights(log_weights), probs
    294     )
    295 elif f_samples.ndim == 1:
    296     percentiles = _prediction_percentiles(
    297         f_samples, normalize_weights(log_weights), probs
    298     )

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/api.py:1127, in vmap.<locals>.vmap_f(*args, **kwargs)
   1124 try:
   1125   axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
   1126                                 explicit_mesh_axis)
-> 1127   out_flat = batching.batch(
   1128       flat_fun, axis_data, in_axes_flat,
   1129       lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
   1130   ).call_wrapped(*args_flat)
   1131 except batching.SpecMatchError as e:
   1132   out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:211, in WrappedFun.call_wrapped(self, *args, **kwargs)
    209 def call_wrapped(self, *args, **kwargs):
    210   """Calls the transformed function"""
--> 211   return self.f_transformed(*args, **kwargs)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:609, in _batch_outer(f, axis_data, in_dims, *in_vals)
    607 tag = TraceTag()
    608 with source_info_util.transform_name_stack('vmap'):
--> 609   outs, trace = f(tag, in_dims, *in_vals)
    610 with core.ensure_no_leaks(trace): del trace
    611 return outs

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:625, in _batch_inner(f, axis_data, out_dim_dests, tag, in_dims, *in_vals)
    621   in_tracers = map(partial(to_elt, trace, idx), in_vals, in_dims)
    622 with (core.set_current_trace(trace),
    623       core.extend_axis_env_nd([(axis_data.name, axis_data.size)]),
    624       core.add_spmd_axis_names(axis_data.spmd_name)):
--> 625   outs = f(*in_tracers)
    626   out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
    627   out_vals = map(partial(from_elt, trace, axis_data.size, axis_data.explicit_mesh_axis),
    628                  range(len(outs)), outs, out_dim_dests)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:340, in flatten_fun_for_vmap(f, store, in_tree, *args_flat)
    336 @lu.transformation_with_aux2
    337 def flatten_fun_for_vmap(f: Callable,
    338                          store: lu.Store, in_tree: PyTreeDef, *args_flat):
    339   py_args, py_kwargs = tree_unflatten(in_tree, args_flat)
--> 340   ans = f(*py_args, **py_kwargs)
    341   ans, out_tree = tree_flatten(ans, is_leaf=is_vmappable)
    342   store.store(out_tree)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/linear_util.py:402, in _get_result_paths_thunk(_fun, _store, *args, **kwargs)
    400 @transformation_with_aux2
    401 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
--> 402   ans = _fun(*args, **kwargs)
    403   result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
    404   if _store:
    405     # In some instances a lu.WrappedFun is called multiple times, e.g.,
    406     # the bwd function in a custom_vjp

File ~/workspace/work/packages/isssm/isssm/importance_sampling.py:175, in _prediction_percentiles(Y, weights, probs)
    170 # find indices of cumulative sum closest to probs
    171 # take corresponding Y_sorted values
    172 # with linear interpolation if necessary
    174 indices = jnp.searchsorted(cumsum, probs)
--> 175 indices = jnp.clip(indices, 1, len(Y_sorted) - 1)
    176 left_indices = indices - 1
    177 right_indices = indices

    [... skipping hidden 1 frame]

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:334, in _cpp_pjit.<locals>.cache_miss(*args, **kwargs)
    329 if config.no_tracing.value:
    330   raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for "
    331                      "`jit`, but 'no_tracing' is set")
    333 (outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked, box_data,
--> 334  executable, pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
    336 maybe_fastpath_data = _get_fastpath_data(
    337     executable, out_tree, args_flat, out_flat, attrs_tracked, box_data,
    338     jaxpr.effects, jaxpr.consts, jit_info.abstracted_axes, pgle_profiler)
    340 return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:197, in _python_pjit_helper(fun, jit_info, *args, **kwargs)
    195   out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
    196 else:
--> 197   out_flat = pjit_p.bind(*args_flat, **p.params)
    198   compiled = None
    199   profiler = None

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/interpreters/batching.py:496, in BatchTrace.process_primitive(self, p, tracers, params)
    494   else:
    495     with core.set_current_trace(self.parent_trace):
--> 496       val_out, dim_out = fancy_primitive_batchers[p](
    497           self.axis_data, vals_in, dims_in, **params)
    498 elif args_not_mapped:
    499   # no-op shortcut
    500   return p.bind_with_trace(self.parent_trace, vals_in, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:2172, in _pjit_batcher(axis_data, vals_in, dims_in, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs)
   2167 if not (all(l is None for l in in_layouts) and
   2168         all(l is None for l in out_layouts)):
   2169   raise NotImplementedError(
   2170       'Concrete layouts are not supported for vmap(jit).')
-> 2172 vals_out = pjit_p.bind(
   2173   *vals_in,
   2174   jaxpr=new_jaxpr,
   2175   in_shardings=in_shardings,
   2176   out_shardings=out_shardings,
   2177   in_layouts=in_layouts,
   2178   out_layouts=out_layouts,
   2179   donated_invars=donated_invars,
   2180   ctx_mesh=ctx_mesh,
   2181   name=name,
   2182   keep_unused=keep_unused,
   2183   inline=inline,
   2184   compiler_options_kvs=compiler_options_kvs)
   2186 resolved_axes_out = batching.resolve_ragged_axes_against_inputs_outputs(
   2187     vals_in, vals_out, axes_out)
   2188 return vals_out, resolved_axes_out

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:531, in Primitive.bind(self, *args, **params)
    529 def bind(self, *args, **params):
    530   args = args if self.skip_canonicalization else map(canonicalize_value, args)
--> 531   return self._true_bind(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:551, in Primitive._true_bind(self, *args, **params)
    549 trace_ctx.set_trace(eval_trace)
    550 try:
--> 551   return self.bind_with_trace(prev_trace, args, params)
    552 finally:
    553   trace_ctx.set_trace(prev_trace)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:556, in Primitive.bind_with_trace(self, trace, args, params)
    555 def bind_with_trace(self, trace, args, params):
--> 556   return trace.process_primitive(self, args, params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/core.py:1060, in EvalTrace.process_primitive(self, primitive, args, params)
   1058 args = map(full_lower, args)
   1059 check_eval_args(args)
-> 1060 return primitive.impl(*args, **params)

File ~/workspace/work/phd/thesis/.venv/lib/python3.10/site-packages/jax/_src/pjit.py:1928, in _pjit_call_impl(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, donated_invars, ctx_mesh, name, keep_unused, inline, compiler_options_kvs, *args)
   1920 donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
   1921 cache_key = pxla.JitGlobalCppCacheKeys(
   1922     donate_argnums=donated_argnums, donate_argnames=None,
   1923     device=None, backend=None,
   (...)
   1926     in_layouts_treedef=None, in_layouts_leaves=in_layouts,
   1927     out_layouts_treedef=None, out_layouts_leaves=out_layouts)
-> 1928 return xc._xla.pjit(
   1929     name, f, call_impl_cache_miss, [], [], cache_key,
   1930     tree_util.dispatch_registry, pxla.cc_shard_arg,
   1931     _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)

KeyboardInterrupt: 
# Re-run the nowcast for a single date and show the last-day mean, standard
# deviation and quantiles.
date_to_investigate = date(2021, 12, 4)

config = Configuration.from_dates(
    start=date_to_investigate - timedelta(days=n_days_back),
    end=date_to_investigate,
    n_delay=7,
)._replace(n_weekday=0)
result, debugging_info = nowcast_hospitalizations(config)
result[0][-1], result[1][-1], result[2][:, -1]
(Array(10214.00893297, dtype=float64),
 Array(680.32175682, dtype=float64),
 Array([ 8307.,  9540.,  9574., 10636., 10646., 10646., 10860.], dtype=float64))
from isssm.laplace_approximation import posterior_mode

# Posterior mode computed from the proposal of the run above.
plt.plot(posterior_mode(debugging_info.proposal))

# Predicted mean against the row sums of `h`, with the first and last
# quantile rows drawn as dashed gray lines.
plt.plot(result[0])
plt.plot(config.h.sum(axis=-1))
plt.plot(result[2][0], linestyle="--", color="gray")
plt.plot(result[2][-1], linestyle="--", color="gray")
plt.show()

Prepare output

from ssm4epi.models.util import result_to_series
# One row per nowcast date: the summary produced by `result_to_series`,
# tagged with the date it belongs to.
df = pd.concat(
    [
        result_to_series(preds, percentiles_of_interest).to_frame().T.assign(date=day)
        for preds, day in zip(
            results, pd.date_range(date_start, date_end, freq="D").date
        )
    ],
    ignore_index=True,
)
df
mean sd 2.5 % 10.0 % 25.0 % 50.0 % 75.0 % 90.0 % 97.5 % date
0 9924.819221 217.768947 9400.000000 9646.000000 9783.000000 9962.0 10133.00000 10184.0 10206.000000 2021-11-23
1 10380.228649 452.012002 9449.000000 10002.000000 10236.000000 10357.0 10377.00000 11270.0 11454.000000 2021-11-24
2 9938.245825 565.993988 9622.000000 9622.000000 9622.000000 9741.0 9741.00000 10975.0 11307.000000 2021-11-25
3 10287.936399 663.933208 9421.000000 9611.000000 9664.186002 10134.0 10688.00000 11236.0 11804.000000 2021-11-26
4 9941.978967 789.978665 8891.373377 8953.000000 9307.000000 9836.0 10455.00000 11315.0 11376.000000 2021-11-27
... ... ... ... ... ... ... ... ... ... ...
400 9157.216763 379.259969 8610.000000 8685.000000 8791.000000 9183.0 9245.00000 9546.0 10058.000000 2022-12-28
401 7629.000000 0.000000 7629.000000 7629.000000 7629.000000 7629.0 7629.00000 7629.0 7629.000000 2022-12-29
402 9160.083345 269.441754 8556.000000 8737.000000 9047.000000 9217.0 9345.00000 9391.0 9625.275487 2022-12-30
403 8442.505495 358.689895 7887.000000 8053.059233 8149.000000 8441.0 8734.00000 8923.0 9264.000000 2022-12-31
404 10767.918527 976.223671 8984.000000 9309.000000 9957.000000 10654.0 11773.07568 11937.0 11970.000000 2023-01-01

405 rows × 10 columns

# Last-day quantiles (rows 0, 3 and 6 of the quantile array) of every nowcast.
plt.plot(jnp.array([result[2][[0, 3, 6], -1] for result in results]))

# Row sums of the full `h` matrix for the nowcast dates.
# NOTE(review): the upper bound looks up `date_end + 1 day`, which raises if
# `date_end` is the last date in `dates`.
plt.plot(
    h[
        find_index_of_date(dates, date_start) : find_index_of_date(
            dates, date_end + timedelta(days=1)
        )
    ].sum(axis=1),
)
plt.ylim(0, 20000)
plt.show()

# Write the nowcast table to the project's results folder (path relative to the project root).
df.to_csv(here("data/results/4_hospitalizations/nowcast/nowcast.csv"), index=False)